
package com.jstarcraft.ai.jsat.classifiers.bayesian.graphicalmodel;

import static java.lang.Math.exp;
import static java.lang.Math.log;

import java.util.HashMap;
import java.util.Map;
import java.util.Set;

import com.jstarcraft.ai.jsat.classifiers.CategoricalData;
import com.jstarcraft.ai.jsat.classifiers.CategoricalResults;
import com.jstarcraft.ai.jsat.classifiers.ClassificationDataSet;
import com.jstarcraft.ai.jsat.classifiers.Classifier;
import com.jstarcraft.ai.jsat.classifiers.DataPoint;
import com.jstarcraft.ai.jsat.classifiers.DataPointPair;
import com.jstarcraft.ai.jsat.classifiers.bayesian.ConditionalProbabilityTable;
import com.jstarcraft.ai.jsat.classifiers.bayesian.NaiveBayes;
import com.jstarcraft.ai.jsat.exceptions.FailedToFitException;

import it.unimi.dsi.fastutil.ints.IntOpenHashSet;

/**
 * A class for representing a Bayesian Network (BN) for discrete variables. A BN
 * uses a graph to represent the relations between variables, and these links
 * are called the structure. The structure of a BN must be specified by an
 * expert using the {@link #depends(int, int) } method. The target class should
 * be specified as the parent of the variables which have a causal relationship
 * to it. These children of the target class should then have their own children
 * specified. Once the structure has been specified, the network can be trained
 * and used for classification. <br>
 * If the network structure has not been specified, or has no relationships for
 * the target class, the BN will create an edge from the target class to every
 * variable. If no edges were ever specified, this initialization of edges
 * corresponds to a {@link NaiveBayes} implementation.
 * 
 * @author Edward Raff
 */
public class DiscreteBayesNetwork implements Classifier {

    private static final long serialVersionUID = 2980734594356260141L;
    /**
     * The directed graph that represents this BN. Nodes are variable indices;
     * the target class uses the index equal to the number of categorical
     * variables.
     */
    protected DirectedGraph<Integer> dag;
    /**
     * The conditional probability tables, keyed by the index of the variable
     * (a direct child of the class node) each table was trained for.
     */
    protected Map<Integer, ConditionalProbabilityTable> cpts;
    /**
     * The class we are predicting
     */
    protected CategoricalData predicting;
    /**
     * The prior probabilities of each class value
     */
    protected double[] priors;
    private boolean usePriors = DEFAULT_USE_PRIORS;

    /**
     * Whether or not the classifier should take into account the prior
     * probabilities. Default value is {@value #DEFAULT_USE_PRIORS}.
     */
    public static final boolean DEFAULT_USE_PRIORS = true;

    /**
     * Creates a network with an empty structure. Edges may be added with
     * {@link #depends(int, int)} before training.
     */
    public DiscreteBayesNetwork() {
        dag = new DirectedGraph<Integer>();
    }

    /**
     * Computes a normalized probability for every class value of the given data
     * point. For each candidate class value the log of every class-child CPT
     * query is summed (plus the log prior when priors are enabled), and the
     * resulting log scores are normalized so the probabilities sum to one.
     *
     * @param data the data point to classify
     * @return the probability of each class value
     */
    public CategoricalResults classify(DataPoint data) {
        CategoricalResults cr = new CategoricalResults(predicting.getNumOfCategories());

        int classId = data.numCategoricalValues();
        // The children of the class node do not depend on the candidate class value
        Set<Integer> classChildren = dag.getChildren(classId);

        // Use log probabilities to avoid underflow
        double[] logProbs = new double[cr.size()];
        for (int i = 0; i < logProbs.length; i++) {
            DataPointPair<Integer> dpp = new DataPointPair<Integer>(data, i);
            for (int classParent : classChildren)
                logProbs[i] += log(cpts.get(classParent).query(classParent, dpp));

            if (usePriors)
                logProbs[i] += log(priors[i]);
        }

        double[] probs = normalizeLogProbabilities(logProbs);
        for (int i = 0; i < probs.length; i++)
            cr.setProb(i, probs[i]);

        return cr;
    }

    /**
     * Converts unnormalized log scores into probabilities that sum to one.
     * The largest log score is subtracted before exponentiating (log-sum-exp),
     * so the largest term becomes {@code exp(0) = 1} and the sum can neither
     * overflow nor underflow to zero. Subtracting the <i>sum</i> of the log
     * scores instead would divide by the product of the scores, which is not a
     * normalization.
     *
     * @param logProbs the unnormalized log score of each class value
     * @return a new array holding the normalized probabilities
     */
    private static double[] normalizeLogProbabilities(double[] logProbs) {
        double[] probs = new double[logProbs.length];

        double maxLog = Double.NEGATIVE_INFINITY;
        for (double logProb : logProbs)
            maxLog = Math.max(maxLog, logProb);

        if (maxLog == Double.NEGATIVE_INFINITY) {
            // Every class has zero likelihood, so the evidence favours no class:
            // fall back to a uniform distribution instead of producing NaN.
            for (int i = 0; i < probs.length; i++)
                probs[i] = 1.0 / probs.length;
            return probs;
        }

        double sum = 0;
        for (int i = 0; i < probs.length; i++) {
            probs[i] = exp(logProbs[i] - maxLog);
            sum += probs[i];
        }
        for (int i = 0; i < probs.length; i++)
            probs[i] /= sum;

        return probs;
    }

    /**
     * Adds a dependency relationship between two variables that will be in the
     * network, by adding both nodes and a directed edge from {@code parent} to
     * {@code child}. The integer value corresponds to the index of the i'th
     * categorical variable, where the class target's value is the number of
     * categorical variables.
     * 
     * @param parent the variable the edge starts from
     * @param child  the variable the edge points to
     */
    public void depends(int parent, int child) {
        dag.addNode(child);
        dag.addNode(parent);
        dag.addEdge(parent, child);
    }

    /**
     * Trains the network on the given data set. The {@code parallel} flag is
     * ignored; training always delegates to {@link #train(ClassificationDataSet)}.
     *
     * @param dataSet  the data set to train on
     * @param parallel ignored
     */
    public void train(ClassificationDataSet dataSet, boolean parallel) {
        train(dataSet);
    }

    /**
     * Trains one conditional probability table for every direct child of the
     * class node. Each table is trained over the child, the child's own
     * children and the class. If no structure was specified, or the class node
     * has no outgoing edges, an edge from the class to every categorical
     * variable is added first.
     *
     * @param dataSet the data set to train on
     * @throws FailedToFitException if the data set has no categorical variables
     */
    public void train(ClassificationDataSet dataSet) {
        int classID = dataSet.getNumCategoricalVars();
        if (classID == 0)
            throw new FailedToFitException("Network needs categorical attribtues to work");

        predicting = dataSet.getPredicting();
        priors = dataSet.getPriors();
        cpts = new HashMap<Integer, ConditionalProbabilityTable>();
        IntOpenHashSet cptTrainSet = new IntOpenHashSet();

        if (classHasNoChildren(classID)) {
            for (int i = 0; i < classID; i++)
                depends(classID, i);
        }

        for (int classParent : dag.getChildren(classID)) {
            Set<Integer> depends = dag.getChildren(classParent);
            ConditionalProbabilityTable cpt = new ConditionalProbabilityTable();

            // Reuse one set across iterations; it holds the variables this CPT spans
            cptTrainSet.clear();
            cptTrainSet.addAll(depends);
            cptTrainSet.add(classParent);
            cptTrainSet.add(classID);
            cpt.trainC(dataSet, cptTrainSet);
            cpts.put(classParent, cpt);
        }
    }

    /**
     * Reports whether the class node currently has no outgoing edges, either
     * because the graph is empty or because no edge starting at the class was
     * ever added.
     *
     * @param classID the node index of the target class
     * @return {@code true} if default class edges need to be created
     */
    private boolean classHasNoChildren(int classID) {
        if (dag.getNodes().isEmpty())
            return true;
        // NOTE(review): assumes getChildren yields null or an empty set for a node
        // that is absent or has no outgoing edges - confirm against DirectedGraph.
        Set<Integer> children = dag.getChildren(classID);
        return children == null || children.isEmpty();
    }

    /**
     * @return {@code false}; this implementation reports no support for
     *         weighted data
     */
    public boolean supportsWeightedData() {
        return false;
    }

    /**
     * Cloning is not implemented.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Classifier clone() {
        throw new UnsupportedOperationException("Not supported yet.");
    }

}
